{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-08-29T06:58:48.123910Z",
     "iopub.status.busy": "2022-08-29T06:58:48.123416Z",
     "iopub.status.idle": "2022-08-29T06:58:48.658593Z",
     "shell.execute_reply": "2022-08-29T06:58:48.657816Z",
     "shell.execute_reply.started": "2022-08-29T06:58:48.123858Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x         x1    x2    x3    x4    x5    x6        x7        x8        x9\n",
      "0     0.03  0.48  0.54  0.54  0.15  0.14  0.142857  0.150000  0.142857\n",
      "1     0.03  0.48  0.54  0.54  0.14  0.15  0.142857  0.142857  0.135135\n",
      "2     0.06  0.46  0.54  0.54  0.15  0.17  0.380952  0.400000  0.131579\n",
      "3     0.03  0.48  0.54  0.54  0.16  0.13  0.150000  0.150000  0.131579\n",
      "4     0.00  0.48  0.55  0.55  0.15  0.14  0.380952  0.380952  0.135135\n",
      "...    ...   ...   ...   ...   ...   ...       ...       ...       ...\n",
      "6175  0.42  0.40  0.31  0.57  0.19  0.06 -1.000000 -1.000000 -1.000000\n",
      "6176  0.42  0.40  0.31  0.58  0.21  0.04 -1.000000 -1.000000 -1.000000\n",
      "6177  0.42  0.40  0.31  0.59  0.19  0.04 -1.000000 -1.000000 -1.000000\n",
      "6178  0.04  0.36  0.43  0.56  0.11  0.19 -1.000000 -1.000000 -1.000000\n",
      "6179  0.83  0.40  0.56  1.62  0.00  0.51 -1.000000 -1.000000 -1.000000\n",
      "\n",
      "[6180 rows x 9 columns]\n",
      "训练集和测试集 shape (4326, 9) (4326,) (1854, 9) (1854,)\n",
      "RandomForestClassifier()\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       0.90      0.94      0.92       759\n",
      "           1       0.89      0.77      0.82       282\n",
      "           2       0.93      0.93      0.93       813\n",
      "\n",
      "    accuracy                           0.91      1854\n",
      "   macro avg       0.90      0.88      0.89      1854\n",
      "weighted avg       0.91      0.91      0.91      1854\n",
      "\n",
      "[[714  15  30]\n",
      " [ 35 216  31]\n",
      " [ 44  13 756]]\n",
      "Accuracy: 90.94%\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "from sklearn.ensemble import RandomForestClassifier  # 导入sklearn库的RandomForestClassifier函数\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn import metrics  # 分类结果评价函数\n",
    "import joblib\n",
    "data = pd.read_csv('data/all_data.txt', names=['x1','x2','x3','x4','x5','x6','x7','x8','x9','y'], encoding='utf-8',sep=',')\n",
    "# print(data)\n",
    "\n",
    "\n",
    "x = data[['x1', 'x2', 'x3','x4','x5','x6','x7','x8','x9']]  # 特征\n",
    "y = data['y']  # 标签\n",
    "print('x',x)\n",
    "# print('y',y)\n",
    "# 划分数据集8\n",
    "x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=0, train_size=0.7)\n",
    "print('训练集和测试集 shape', x_train.shape, y_train.shape, x_test.shape, y_test.shape)\n",
    "\n",
    "model = RandomForestClassifier()  # 实例化模型RandomForestClassifier\n",
    "model.fit(x_train, y_train)  # 在训练集上训练模型\n",
    "joblib.dump(model,\"test_model_randomforest.m\")\n",
    "print(model)  # 输出模型RandomForestClassifier\n",
    "\n",
    "# 在测试集上测试模型\n",
    "expected = y_test  # 测试样本的期望输出\n",
    "predicted = model.predict(x_test)  # 测试样本预测\n",
    "\n",
    "# 输出结果\n",
    "print(metrics.classification_report(expected, predicted))  # 输出结果，精确度、召回率、f-1分数\n",
    "print(metrics.confusion_matrix(expected, predicted))  # 混淆矩阵\n",
    "\n",
    "# auc = metrics.roc_auc_score(y_test, predicted)\n",
    "accuracy = metrics.accuracy_score(y_test, predicted)  # 求精度\n",
    "print(\"Accuracy: %.2f%%\" % (accuracy * 100.0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "naas"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
